{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader\n",
    "\n",
    "# Create the loader once, then ask it for the dataset object that the\n",
    "# following cells work with.\n",
    "loader = ChickenpoxDatasetLoader()\n",
    "dataset = loader.get_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric_temporal.signal import temporal_signal_split\n",
    "\n",
    "# `dataset` is the object built in the loading cell above; constructing the\n",
    "# loader a second time here would only repeat that work.\n",
    "# train_ratio=0.8 requests 80% of the signal for training, the rest for test.\n",
    "train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20, 4)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the feature array stored at index 1 of the training split.\n",
    "# Read through the attribute rather than poking at the instance __dict__.\n",
    "train_dataset.features[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric_temporal.nn.recurrent import DCRNN\n",
    "\n",
    "\n",
    "class RecurrentGCN(torch.nn.Module):\n",
    "    \"\"\"DCRNN layer followed by a ReLU and a single-output linear layer.\n",
    "\n",
    "    Args:\n",
    "        node_features: Input channel count handed to DCRNN.\n",
    "        hidden_channels: Output width of the DCRNN layer and input width\n",
    "            of the linear head. Defaults to 32.\n",
    "        filter_size: Third positional argument of DCRNN (its ``K``).\n",
    "            Defaults to 1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, node_features, hidden_channels=32, filter_size=1):\n",
    "        super().__init__()\n",
    "        self.recurrent = DCRNN(node_features, hidden_channels, filter_size)\n",
    "        # The head always maps to one value per row of the DCRNN output.\n",
    "        self.linear = torch.nn.Linear(hidden_channels, 1)\n",
    "\n",
    "    def forward(self, x, edge_index, edge_weight):\n",
    "        \"\"\"Run one snapshot through the recurrent layer and the head.\n",
    "\n",
    "        Only ``x``, ``edge_index`` and ``edge_weight`` are forwarded to\n",
    "        DCRNN; no hidden state is passed in or returned by this method.\n",
    "\n",
    "        Returns:\n",
    "            The linear head's output, whose last dimension has size 1.\n",
    "        \"\"\"\n",
    "        h = self.recurrent(x, edge_index, edge_weight)\n",
    "        return self.linear(F.relu(h))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# 4 matches the trailing dimension of the feature arrays inspected above.\n",
    "model = RecurrentGCN(node_features=4)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "model.train()\n",
    "\n",
    "for epoch in tqdm(range(200)):\n",
    "    # Accumulate the per-snapshot mean squared error over the training split.\n",
    "    cost = 0\n",
    "    for time, snapshot in enumerate(train_dataset):\n",
    "        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)\n",
    "        # The model's linear head ends in a size-1 dimension. Align the\n",
    "        # prediction with the target's shape before subtracting so the\n",
    "        # difference cannot silently broadcast into an N x N matrix when\n",
    "        # the target is 1-D.\n",
    "        residual = y_hat.reshape(snapshot.y.shape) - snapshot.y\n",
    "        cost = cost + torch.mean(residual ** 2)\n",
    "    # Average over the snapshots seen, then take one optimizer step per epoch.\n",
    "    cost = cost / (time + 1)\n",
    "    cost.backward()\n",
    "    optimizer.step()\n",
    "    optimizer.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show all three shapes as one tuple; separate bare expressions would only\n",
    "# display the last one. `snapshot` and `y_hat` are the values left behind\n",
    "# by the training-loop cell above.\n",
    "snapshot.x.shape, y_hat.shape, snapshot.edge_index.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pinss",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
